{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp models.stemgnn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# StemGNN"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The Spectral Temporal Graph Neural Network (`StemGNN`) is a Graph-based multivariate time-series forecasting model. `StemGNN` jointly learns temporal dependencies and inter-series correlations in the spectral domain, by combining Graph Fourier Transform (GFT) and Discrete Fourier Transform (DFT). \n",
    "\n",
    "This method achieved state-of-the-art performance on geo-temporal datasets such as `Solar`, `METR-LA`, and `PEMS-BAY`.\n",
    "\n",
    "**References**<br>\n",
    "- [Defu Cao, Yujing Wang, Juanyong Duan, Ce Zhang, Xia Zhu, Congrui Huang, Yunhai Tong, Bixiong Xu, Jing Bai, Jie Tong, Qi Zhang (2020). \"Spectral Temporal Graph Neural Network for Multivariate Time-series Forecasting\".](https://proceedings.neurips.cc/paper/2020/hash/cdf6581cb7aca4b7e19ef136c6e601a5-Abstract.html)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![Figure 1. StemGNN.](imgs_models/stemgnn.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from fastcore.test import test_eq\n",
    "from nbdev.showdoc import show_doc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from neuralforecast.losses.pytorch import MAE\n",
    "from neuralforecast.common._base_multivariate import BaseMultivariate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class GLU(nn.Module):\n",
    "    def __init__(self, input_channel, output_channel):\n",
    "        super(GLU, self).__init__()\n",
    "        self.linear_left = nn.Linear(input_channel, output_channel)\n",
    "        self.linear_right = nn.Linear(input_channel, output_channel)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return torch.mul(self.linear_left(x), torch.sigmoid(self.linear_right(x)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class StockBlockLayer(nn.Module):\n",
    "    \"\"\"Single StemGNN block.\n",
    "\n",
    "    `forward` multiplies the input by the stacked graph operators `mul_L`, runs the\n",
    "    result through `spe_seq_cell` (FFT, three GLU layers on the real and imaginary\n",
    "    parts, inverse FFT), mixes the graph orders with a learned weight and emits a\n",
    "    forecast. Only the block built with `stack_cnt == 0` also emits a backcast.\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `time_step`: int, size of the last (time) dimension of the block input.<br>\n",
    "    `unit`: int, stored on the instance; not used inside the block.<br>\n",
    "    `multi_layer`: int, multiplier applied to `time_step` for the hidden sizes.<br>\n",
    "    `stack_cnt`: int=0, index of the block; a backcast head is only created for index 0.<br>\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, time_step, unit, multi_layer, stack_cnt=0):\n",
    "        super().__init__()\n",
    "        self.time_step = time_step\n",
    "        self.unit = unit\n",
    "        self.stack_cnt = stack_cnt\n",
    "        self.multi = multi_layer\n",
    "\n",
    "        hidden_size = self.time_step * self.multi\n",
    "        # Learned mixing weight: [1, K+1, 1, in_c, out_c] with K + 1 = 4 graph orders.\n",
    "        self.weight = nn.Parameter(torch.Tensor(1, 3 + 1, 1, hidden_size, hidden_size))\n",
    "        nn.init.xavier_normal_(self.weight)\n",
    "\n",
    "        self.forecast = nn.Linear(hidden_size, hidden_size)\n",
    "        self.forecast_result = nn.Linear(hidden_size, self.time_step)\n",
    "        if self.stack_cnt == 0:\n",
    "            # Only the first block produces a backcast (see `forward`).\n",
    "            self.backcast = nn.Linear(hidden_size, self.time_step)\n",
    "        self.backcast_short_cut = nn.Linear(self.time_step, self.time_step)\n",
    "        self.relu = nn.ReLU()\n",
    "\n",
    "        self.GLUs = nn.ModuleList()\n",
    "        self.output_channel = 4 * self.multi\n",
    "        glu_width = self.time_step * self.output_channel\n",
    "        # Three GLU layers; each layer contributes two GLUs stored back to back:\n",
    "        # even index -> real part, odd index -> imaginary part (see `spe_seq_cell`).\n",
    "        for in_features in (self.time_step * 4, glu_width, glu_width):\n",
    "            self.GLUs.append(GLU(in_features, glu_width))\n",
    "            self.GLUs.append(GLU(in_features, glu_width))\n",
    "\n",
    "    def spe_seq_cell(self, input):\n",
    "        \"\"\"Spectral sequential cell: FFT over dim 1, stacked GLUs, inverse FFT over dim 1.\n",
    "\n",
    "        `input` must be 5-dimensional; its second and third dimensions are merged\n",
    "        before the transform.\n",
    "        \"\"\"\n",
    "        batch_size, _, _, node_cnt, time_step = input.size()\n",
    "        merged = input.view(batch_size, -1, node_cnt, time_step)\n",
    "        spectrum = torch.view_as_real(torch.fft.fft(merged, dim=1))\n",
    "\n",
    "        def flatten_per_node(part):\n",
    "            # [batch, order, node, time] -> [batch, node, order * time]\n",
    "            return part.permute(0, 2, 1, 3).contiguous().reshape(batch_size, node_cnt, -1)\n",
    "\n",
    "        def restore_orders(part):\n",
    "            # [batch, node, 4 * features] -> [batch, 4, node, features]\n",
    "            return part.reshape(batch_size, node_cnt, 4, -1).permute(0, 2, 1, 3).contiguous()\n",
    "\n",
    "        real = flatten_per_node(spectrum[..., 0])\n",
    "        img = flatten_per_node(spectrum[..., 1])\n",
    "        for layer in range(3):\n",
    "            real = self.GLUs[2 * layer](real)\n",
    "            img = self.GLUs[2 * layer + 1](img)\n",
    "        real = restore_orders(real)\n",
    "        img = restore_orders(img)\n",
    "\n",
    "        # Pair real and imaginary parts on a trailing axis so they can be viewed as complex.\n",
    "        time_step_as_inner = torch.stack([real, img], dim=-1)\n",
    "        # NOTE(review): the forward transform is a full complex FFT while the inverse is\n",
    "        # `irfft` (one-sided input) -- confirm this asymmetry is intended.\n",
    "        return torch.fft.irfft(\n",
    "            torch.view_as_complex(time_step_as_inner),\n",
    "            n=time_step_as_inner.shape[1],\n",
    "            dim=1,\n",
    "        )\n",
    "\n",
    "    def forward(self, x, mul_L):\n",
    "        \"\"\"Return `(forecast, backcast)`; `backcast` is `None` unless `stack_cnt == 0`.\"\"\"\n",
    "        graph_ops = mul_L.unsqueeze(1)\n",
    "        x = x.unsqueeze(1)\n",
    "        gfted = torch.matmul(graph_ops, x)\n",
    "        gconv_input = self.spe_seq_cell(gfted).unsqueeze(2)\n",
    "        # Weight every graph order, then collapse the order dimension.\n",
    "        igfted = torch.matmul(gconv_input, self.weight).sum(dim=1)\n",
    "\n",
    "        forecast_source = torch.sigmoid(self.forecast(igfted).squeeze(1))\n",
    "        forecast = self.forecast_result(forecast_source)\n",
    "        if self.stack_cnt != 0:\n",
    "            return forecast, None\n",
    "\n",
    "        backcast_short = self.backcast_short_cut(x).squeeze(1)\n",
    "        backcast_source = torch.sigmoid(self.backcast(igfted) - backcast_short)\n",
    "        return forecast, backcast_source"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class StemGNN(BaseMultivariate):\n",
    "    \"\"\" StemGNN\n",
    "\n",
    "    The Spectral Temporal Graph Neural Network (`StemGNN`) is a Graph-based multivariate\n",
    "    time-series forecasting model. `StemGNN` jointly learns temporal dependencies and\n",
    "    inter-series correlations in the spectral domain, by combining Graph Fourier Transform (GFT)\n",
    "    and Discrete Fourier Transform (DFT).\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `h`: int, Forecast horizon. <br>\n",
    "    `input_size`: int, autoregressive inputs size, y=[1,2,3,4] input_size=2 -> y_[t-2:t]=[1,2].<br>\n",
    "    `n_series`: int, number of time-series.<br>\n",
    "    `stat_exog_list`: str list, static exogenous columns (forwarded to the base class; `forward` only reads `insample_y`).<br>\n",
    "    `hist_exog_list`: str list, historic exogenous columns (forwarded to the base class; `forward` only reads `insample_y`).<br>\n",
    "    `futr_exog_list`: str list, future exogenous columns (forwarded to the base class; `forward` only reads `insample_y`).<br>\n",
    "    `n_stacks`: int=2, number of stacks in the model. Only `n_stacks=2` is currently supported; any other value raises an `Exception`.<br>\n",
    "    `multi_layer`: int=5, multiplier for FC hidden size on StemGNN blocks.<br>\n",
    "    `dropout_rate`: float=0.5, dropout rate.<br>\n",
    "    `leaky_rate`: float=0.2, alpha for LeakyReLU layer on Latent Correlation layer.<br>\n",
    "    `loss`: PyTorch module, instantiated train loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).<br>\n",
    "    `valid_loss`: PyTorch module=`loss`, instantiated valid loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).<br>\n",
    "    `max_steps`: int=1000, maximum number of training steps.<br>\n",
    "    `learning_rate`: float=1e-3, Learning rate between (0, 1).<br>\n",
    "    `num_lr_decays`: int=3, Number of learning rate decays, evenly distributed across max_steps.<br>\n",
    "    `early_stop_patience_steps`: int=-1, Number of validation iterations before early stopping.<br>\n",
    "    `val_check_steps`: int=100, Number of training steps between every validation loss check.<br>\n",
    "    `batch_size`: int=32, number of windows in each batch.<br>\n",
    "    `step_size`: int=1, step size between each window of temporal data.<br>\n",
    "    `scaler_type`: str='robust', type of scaler for temporal inputs normalization see [temporal scalers](https://nixtla.github.io/neuralforecast/common.scalers.html).<br>\n",
    "    `random_seed`: int=1, random_seed for pytorch initializer and numpy generators.<br>\n",
    "    `num_workers_loader`: int=0, workers to be used by `TimeSeriesDataLoader`.<br>\n",
    "    `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.<br>\n",
    "    `alias`: str, optional,  Custom name of the model.<br>\n",
    "    `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).<br>\n",
    "    `optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.<br>\n",
    "    `lr_scheduler`: Subclass of 'torch.optim.lr_scheduler.LRScheduler', optional, user specified lr_scheduler instead of the default choice (StepLR).<br>\n",
    "    `lr_scheduler_kwargs`: dict, optional, list of parameters used by the user specified `lr_scheduler`.<br>\n",
    "    `**trainer_kwargs`: int,  keyword trainer arguments inherited from [PyTorch Lightning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br>\n",
    "    \"\"\"\n",
    "    # Class attributes\n",
    "    SAMPLING_TYPE = 'multivariate'\n",
    "    EXOGENOUS_FUTR = False\n",
    "    EXOGENOUS_HIST = False\n",
    "    EXOGENOUS_STAT = False\n",
    "\n",
    "    def __init__(self,\n",
    "                 h,\n",
    "                 input_size,\n",
    "                 n_series,\n",
    "                 futr_exog_list = None,\n",
    "                 hist_exog_list = None,\n",
    "                 stat_exog_list = None,\n",
    "                 n_stacks = 2,\n",
    "                 multi_layer: int = 5,\n",
    "                 dropout_rate: float = 0.5,\n",
    "                 leaky_rate: float = 0.2,\n",
    "                 loss = MAE(),\n",
    "                 valid_loss = None,\n",
    "                 max_steps: int = 1000,\n",
    "                 learning_rate: float = 1e-3,\n",
    "                 num_lr_decays: int = 3,\n",
    "                 early_stop_patience_steps: int =-1,\n",
    "                 val_check_steps: int = 100,\n",
    "                 batch_size: int = 32,\n",
    "                 step_size: int = 1,\n",
    "                 scaler_type: str = 'robust',\n",
    "                 random_seed: int = 1,\n",
    "                 num_workers_loader = 0,\n",
    "                 drop_last_loader = False,\n",
    "                 optimizer = None,\n",
    "                 optimizer_kwargs = None,\n",
    "                 lr_scheduler = None,\n",
    "                 lr_scheduler_kwargs = None,\n",
    "                 **trainer_kwargs):\n",
    "\n",
    "        # Inherit BaseMultivariate class\n",
    "        super(StemGNN, self).__init__(h=h,\n",
    "                                      input_size=input_size,\n",
    "                                      n_series=n_series,\n",
    "                                      futr_exog_list=futr_exog_list,\n",
    "                                      hist_exog_list=hist_exog_list,\n",
    "                                      stat_exog_list=stat_exog_list,\n",
    "                                      loss=loss,\n",
    "                                      valid_loss=valid_loss,\n",
    "                                      max_steps=max_steps,\n",
    "                                      learning_rate=learning_rate,\n",
    "                                      num_lr_decays=num_lr_decays,\n",
    "                                      early_stop_patience_steps=early_stop_patience_steps,\n",
    "                                      val_check_steps=val_check_steps,\n",
    "                                      batch_size=batch_size,\n",
    "                                      step_size=step_size,\n",
    "                                      scaler_type=scaler_type,\n",
    "                                      num_workers_loader=num_workers_loader,\n",
    "                                      drop_last_loader=drop_last_loader,\n",
    "                                      random_seed=random_seed,\n",
    "                                      optimizer=optimizer,\n",
    "                                      optimizer_kwargs=optimizer_kwargs,\n",
    "                                      lr_scheduler=lr_scheduler,\n",
    "                                      lr_scheduler_kwargs=lr_scheduler_kwargs,\n",
    "                                      **trainer_kwargs)\n",
    "        # Quick fix for now, fix the model later: `forward` adds exactly two block\n",
    "        # forecasts and only block 0 returns a backcast for the next block to consume.\n",
    "        if n_stacks != 2:\n",
    "            raise Exception(\"StemGNN currently only supports n_stacks=2.\")\n",
    "\n",
    "        self.unit = n_series\n",
    "        self.stack_cnt = n_stacks\n",
    "        self.alpha = leaky_rate\n",
    "        self.time_step = input_size\n",
    "        self.horizon = h\n",
    "        self.h = h\n",
    "\n",
    "        # Key / query projections of the self-attention that builds the latent graph.\n",
    "        self.weight_key = nn.Parameter(torch.zeros(size=(self.unit, 1)))\n",
    "        nn.init.xavier_uniform_(self.weight_key.data, gain=1.414)\n",
    "        self.weight_query = nn.Parameter(torch.zeros(size=(self.unit, 1)))\n",
    "        nn.init.xavier_uniform_(self.weight_query.data, gain=1.414)\n",
    "        self.GRU = nn.GRU(self.time_step, self.unit)\n",
    "        self.multi_layer = multi_layer\n",
    "        self.stock_block = nn.ModuleList()\n",
    "        self.stock_block.extend(\n",
    "            [StockBlockLayer(self.time_step, self.unit, self.multi_layer, stack_cnt=i) for i in range(self.stack_cnt)])\n",
    "        # Output head: maps the `time_step` features of each series to the horizon\n",
    "        # (times the number of outputs the loss needs per step).\n",
    "        self.fc = nn.Sequential(\n",
    "            nn.Linear(int(self.time_step), int(self.time_step)),\n",
    "            nn.LeakyReLU(),\n",
    "            nn.Linear(int(self.time_step), self.horizon * self.loss.outputsize_multiplier),\n",
    "        )\n",
    "        self.leakyrelu = nn.LeakyReLU(self.alpha)\n",
    "        self.dropout = nn.Dropout(p=dropout_rate)\n",
    "\n",
    "    def get_laplacian(self, graph, normalize):\n",
    "        \"\"\"\n",
    "        return the laplacian of the graph.\n",
    "        :param graph: the graph structure without self loop, [N, N].\n",
    "        :param normalize: whether to used the normalized laplacian.\n",
    "        :return: graph laplacian.\n",
    "        \"\"\"\n",
    "        if normalize:\n",
    "            # I - D^(-1/2) A D^(-1/2)\n",
    "            D = torch.diag(torch.sum(graph, dim=-1) ** (-1 / 2))\n",
    "            L = torch.eye(graph.size(0), device=graph.device, dtype=graph.dtype) - torch.mm(torch.mm(D, graph), D)\n",
    "        else:\n",
    "            # D - A\n",
    "            D = torch.diag(torch.sum(graph, dim=-1))\n",
    "            L = D - graph\n",
    "        return L\n",
    "\n",
    "    def cheb_polynomial(self, laplacian):\n",
    "        \"\"\"\n",
    "        Compute the Chebyshev Polynomial, according to the graph laplacian.\n",
    "        :param laplacian: the graph laplacian, [N, N].\n",
    "        :return: the multi order Chebyshev laplacian, [K, N, N] with K = 4.\n",
    "        \"\"\"\n",
    "        laplacian = laplacian.unsqueeze(0)  # [1, N, N]\n",
    "        # Match the laplacian's dtype and device so reduced or double precision does\n",
    "        # not mix dtypes in the matmuls below.\n",
    "        # NOTE(review): the zeroth-order term is all zeros rather than the identity of\n",
    "        # the textbook Chebyshev recurrence -- kept as is to preserve model behaviour.\n",
    "        first_laplacian = torch.zeros_like(laplacian)\n",
    "        second_laplacian = laplacian\n",
    "        third_laplacian = (2 * torch.matmul(laplacian, second_laplacian)) - first_laplacian\n",
    "        forth_laplacian = 2 * torch.matmul(laplacian, third_laplacian) - second_laplacian\n",
    "        multi_order_laplacian = torch.cat([first_laplacian, second_laplacian, third_laplacian, forth_laplacian], dim=0)\n",
    "        return multi_order_laplacian\n",
    "\n",
    "    def latent_correlation_layer(self, x):\n",
    "        \"\"\"\n",
    "        Learn a latent graph between the series and derive its Chebyshev operators.\n",
    "        :param x: input window, [batch, input_size, n_series].\n",
    "        :return: tuple (multi order Chebyshev laplacian [4, N, N], symmetrized attention [N, N]).\n",
    "        \"\"\"\n",
    "        # The GRU runs over the series axis (sequence-first layout), one window per step.\n",
    "        gru_out, _ = self.GRU(x.permute(2, 0, 1).contiguous())\n",
    "        gru_out = gru_out.permute(1, 0, 2).contiguous()\n",
    "        attention = self.self_graph_attention(gru_out)\n",
    "        # Average the per-sample attention over the batch to get one graph.\n",
    "        attention = torch.mean(attention, dim=0)\n",
    "        # The degree is taken from the attention before it is symmetrized.\n",
    "        degree = torch.sum(attention, dim=1)\n",
    "        # laplacian is sym or not\n",
    "        attention = 0.5 * (attention + attention.T)\n",
    "        degree_l = torch.diag(degree)\n",
    "        # The epsilon guards against division by a zero degree.\n",
    "        diagonal_degree_hat = torch.diag(1 / (torch.sqrt(degree) + 1e-7))\n",
    "        laplacian = torch.matmul(diagonal_degree_hat,\n",
    "                                 torch.matmul(degree_l - attention, diagonal_degree_hat))\n",
    "        mul_L = self.cheb_polynomial(laplacian)\n",
    "        return mul_L, attention\n",
    "\n",
    "    def self_graph_attention(self, input):\n",
    "        \"\"\"\n",
    "        Additive self-attention between series.\n",
    "        :param input: GRU output, [batch, N, features].\n",
    "        :return: attention weights after softmax over the last axis and dropout, [batch, N, N].\n",
    "        \"\"\"\n",
    "        input = input.permute(0, 2, 1).contiguous()\n",
    "        bat, N, _ = input.size()\n",
    "        key = torch.matmul(input, self.weight_key)\n",
    "        query = torch.matmul(input, self.weight_query)\n",
    "        # Pairwise scores: entry (i, j) is key[i] + query[j].\n",
    "        data = key.repeat(1, 1, N).view(bat, N * N, 1) + query.repeat(1, N, 1)\n",
    "        data = data.squeeze(2)\n",
    "        data = data.view(bat, N, -1)\n",
    "        data = self.leakyrelu(data)\n",
    "        attention = F.softmax(data, dim=2)\n",
    "        attention = self.dropout(attention)\n",
    "        return attention\n",
    "\n",
    "    def graph_fft(self, input, eigenvectors):\n",
    "        \"\"\"Project `input` with `eigenvectors` (plain matrix product).\"\"\"\n",
    "        return torch.matmul(eigenvectors, input)\n",
    "\n",
    "    def forward(self, windows_batch):\n",
    "        \"\"\"\n",
    "        Forecast from a batch of windows.\n",
    "        :param windows_batch: dict; only `insample_y` ([batch, input_size, n_series]) is read.\n",
    "        :return: forecast produced by `self.loss.domain_map`, with a trailing axis restored\n",
    "                 when `domain_map` returned a 2-dimensional tensor.\n",
    "        \"\"\"\n",
    "        # Parse batch\n",
    "        x = windows_batch['insample_y']\n",
    "        batch_size = x.shape[0]\n",
    "\n",
    "        # The attention matrix is not needed here, only the graph operators.\n",
    "        mul_L, _ = self.latent_correlation_layer(x)\n",
    "        X = x.unsqueeze(1).permute(0, 1, 3, 2).contiguous()\n",
    "        result = []\n",
    "        for stack_i in range(self.stack_cnt):\n",
    "            # Each block forecasts from the backcast of the previous one.\n",
    "            forecast, X = self.stock_block[stack_i](X, mul_L)\n",
    "            result.append(forecast)\n",
    "        # `__init__` enforces exactly two stacks, so both entries exist.\n",
    "        forecast = result[0] + result[1]\n",
    "        forecast = self.fc(forecast)\n",
    "\n",
    "        forecast = forecast.permute(0, 2, 1).contiguous()\n",
    "        forecast = forecast.reshape(batch_size, self.h, self.loss.outputsize_multiplier * self.n_series)\n",
    "        forecast = self.loss.domain_map(forecast)\n",
    "\n",
    "        # domain_map might have squeezed the last dimension in case n_series == 1\n",
    "        # Note that this fails in case of a tuple loss, but Multivariate does not support tuple losses yet.\n",
    "        if forecast.ndim == 2:\n",
    "            return forecast.unsqueeze(-1)\n",
    "        else:\n",
    "            return forecast"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(StemGNN)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(StemGNN.fit, name='StemGNN.fit')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(StemGNN.predict, name='StemGNN.predict')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "import logging\n",
    "import warnings\n",
    "\n",
    "from neuralforecast import NeuralForecast\n",
    "from neuralforecast.utils import AirPassengersPanel, AirPassengersStatic\n",
    "from neuralforecast.losses.pytorch import MAE, MSE, RMSE, MAPE, SMAPE, MASE, relMSE, QuantileLoss, MQLoss, DistributionLoss,PMM, GMM, NBMM, HuberLoss, TukeyLoss, HuberQLoss, HuberMQLoss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "# Test losses\n",
    "logging.getLogger(\"pytorch_lightning\").setLevel(logging.ERROR)\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "# The last 12 timestamps are held out for testing.\n",
    "test_start = AirPassengersPanel['ds'].values[-12]\n",
    "Y_train_df = AirPassengersPanel[AirPassengersPanel.ds<test_start].reset_index(drop=True) # 132 train\n",
    "Y_test_df = AirPassengersPanel[AirPassengersPanel.ds>=test_start].reset_index(drop=True) # 12 test\n",
    "\n",
    "AirPassengersStatic_single = AirPassengersStatic[AirPassengersStatic[\"unique_id\"] == 'Airline1']\n",
    "Y_train_df_single = Y_train_df[Y_train_df[\"unique_id\"] == 'Airline1']\n",
    "Y_test_df_single = Y_test_df[Y_test_df[\"unique_id\"] == 'Airline1']\n",
    "\n",
    "def _make_losses():\n",
    "    # Builds a new list of loss instances on every call, so `losses` and\n",
    "    # `valid_losses` hold distinct objects.\n",
    "    return [\n",
    "        MAE(), MSE(), RMSE(), MAPE(), SMAPE(), MASE(seasonality=12),\n",
    "        relMSE(y_train=Y_train_df), QuantileLoss(q=0.5), MQLoss(),\n",
    "        DistributionLoss(distribution='Bernoulli'),\n",
    "        DistributionLoss(distribution='Normal'),\n",
    "        DistributionLoss(distribution='Poisson'),\n",
    "        DistributionLoss(distribution='StudentT'),\n",
    "        DistributionLoss(distribution='NegativeBinomial'),\n",
    "        DistributionLoss(distribution='Tweedie'),\n",
    "        PMM(), GMM(), NBMM(), HuberLoss(), TukeyLoss(), HuberQLoss(q=0.5), HuberMQLoss(),\n",
    "    ]\n",
    "\n",
    "losses = _make_losses()\n",
    "valid_losses = _make_losses()\n",
    "\n",
    "# Constructor arguments shared by every model built in this cell.\n",
    "common_kwargs = dict(h=12,\n",
    "                     input_size=24,\n",
    "                     scaler_type='robust',\n",
    "                     max_steps=2,\n",
    "                     early_stop_patience_steps=-1,\n",
    "                     val_check_steps=10,\n",
    "                     learning_rate=1e-3,\n",
    "                     batch_size=32)\n",
    "\n",
    "# Each loss must either train and predict, or fail with the exact\n",
    "# \"not supported in a Multivariate model\" message.\n",
    "for loss, valid_loss in zip(losses, valid_losses):\n",
    "    try:\n",
    "        model = StemGNN(n_series=2, loss=loss, valid_loss=valid_loss, **common_kwargs)\n",
    "\n",
    "        fcst = NeuralForecast(models=[model], freq='M')\n",
    "        fcst.fit(df=Y_train_df, static_df=AirPassengersStatic, val_size=12)\n",
    "        forecasts = fcst.predict(futr_df=Y_test_df)\n",
    "    except Exception as e:\n",
    "        assert str(e) == f\"{loss} is not supported in a Multivariate model.\"\n",
    "\n",
    "\n",
    "# Test n_series = 1\n",
    "model = StemGNN(n_series=1, loss=MAE(), valid_loss=MAE(), **common_kwargs)\n",
    "fcst = NeuralForecast(models=[model], freq='M')\n",
    "fcst.fit(df=Y_train_df_single, static_df=AirPassengersStatic_single, val_size=12)\n",
    "forecasts = fcst.predict(futr_df=Y_test_df_single)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Usage Examples"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Train model and forecast future values with `predict` method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| eval: false\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import pytorch_lightning as pl\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from neuralforecast import NeuralForecast\n",
    "from neuralforecast.utils import AirPassengersPanel, AirPassengersStatic\n",
    "from neuralforecast.losses.pytorch import MAE\n",
    "\n",
    "Y_train_df = AirPassengersPanel[AirPassengersPanel.ds<AirPassengersPanel['ds'].values[-12]].reset_index(drop=True) # 132 train\n",
    "Y_test_df = AirPassengersPanel[AirPassengersPanel.ds>=AirPassengersPanel['ds'].values[-12]].reset_index(drop=True) # 12 test\n",
    "\n",
    "# StemGNN declares no exogenous support (EXOGENOUS_FUTR/HIST/STAT are False) and its\n",
    "# forward pass only reads `insample_y`, so no exogenous column lists are passed here.\n",
    "model = StemGNN(h=12,\n",
    "                input_size=24,\n",
    "                n_series=2,\n",
    "                scaler_type='robust',\n",
    "                max_steps=500,\n",
    "                early_stop_patience_steps=-1,\n",
    "                val_check_steps=10,\n",
    "                learning_rate=1e-3,\n",
    "                loss=MAE(),\n",
    "                valid_loss=None,\n",
    "                batch_size=32\n",
    "                )\n",
    "\n",
    "fcst = NeuralForecast(models=[model], freq='M')\n",
    "fcst.fit(df=Y_train_df, static_df=AirPassengersStatic, val_size=12)\n",
    "forecasts = fcst.predict(futr_df=Y_test_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| eval: false\n",
    "# Plot predictions\n",
    "fig, ax = plt.subplots(1, 1, figsize = (20, 7))\n",
    "\n",
    "# Put the forecast columns next to the test rows, then append them to the training history.\n",
    "Y_hat_df = forecasts.reset_index(drop=False).drop(columns=['unique_id','ds'])\n",
    "test_with_forecast = pd.concat([Y_test_df, Y_hat_df], axis=1)\n",
    "plot_df = pd.concat([Y_train_df, test_with_forecast])\n",
    "\n",
    "# Keep a single series for the figure.\n",
    "plot_df = plot_df[plot_df['unique_id'] == 'Airline1'].drop(columns='unique_id')\n",
    "ax.plot(plot_df['ds'], plot_df['y'], color='black', label='True')\n",
    "ax.plot(plot_df['ds'], plot_df['StemGNN'], color='blue', label='Forecast')\n",
    "ax.set_title('AirPassengers Forecast', fontsize=22)\n",
    "ax.set_ylabel('Monthly Passengers', fontsize=20)\n",
    "ax.set_xlabel('Year', fontsize=20)\n",
    "ax.legend(prop={'size': 15})\n",
    "ax.grid()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Using `cross_validation` to forecast multiple historic values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| eval: false\n",
    "# Cross-validate the `model` built in the previous example (`n_windows=2`, `step_size=12`).\n",
    "fcst = NeuralForecast(models=[model], freq='M')\n",
    "forecasts = fcst.cross_validation(df=AirPassengersPanel, static_df=AirPassengersStatic, n_windows=2, step_size=12)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| eval: false\n",
    "# Plot predictions\n",
    "fig, ax = plt.subplots(1, 1, figsize = (20, 7))\n",
    "\n",
    "# Cross-validation forecasts and observed values for a single series.\n",
    "Y_hat_df = forecasts.loc['Airline1']\n",
    "is_airline1 = AirPassengersPanel['unique_id'] == 'Airline1'\n",
    "Y_df = AirPassengersPanel[is_airline1]\n",
    "\n",
    "ax.plot(Y_df['ds'], Y_df['y'], color='black', label='True')\n",
    "ax.plot(Y_hat_df['ds'], Y_hat_df['StemGNN'], color='blue', label='Forecast')\n",
    "ax.set_title('AirPassengers Forecast', fontsize=22)\n",
    "ax.set_ylabel('Monthly Passengers', fontsize=20)\n",
    "ax.set_xlabel('Year', fontsize=20)\n",
    "ax.legend(prop={'size': 15})\n",
    "ax.grid()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
